{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "import glob\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "from os.path import join, basename, isdir\n",
    "from os import listdir, makedirs\n",
    "import shutil as sh\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sort Trials"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Work from inside the run directory: every later cell uses paths relative to it.\n",
    "HOME = os.environ[\"HOME\"]\n",
    "os.chdir(os.path.join(HOME, 'tmp/<run-name>/'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Trial directories: sub-directories of the working directory whose name contains 'neural'.\n",
    "dirs = [d for d in listdir('.') if isdir(d) and 'neural' in d]\n",
    "\n",
    "# durations[preamble][agents][duration] -> list of trial directories, where\n",
    "#   preamble is 'shared' if the directory name contains 'shared', else 'private';\n",
    "#   agents   is 'nc' if the directory name contains 'classic', else 'nn';\n",
    "#   duration is the float on the first line of the trial's 'duration' file.\n",
    "durations = {}\n",
    "durations['private'] = {'nc': defaultdict(list), 'nn': defaultdict(list)}\n",
    "durations['shared'] = {'nc': defaultdict(list), 'nn': defaultdict(list)}\n",
    "for d in dirs:\n",
    "    # Context manager so the handle is closed right away instead of leaking one per trial.\n",
    "    with open(join(d, 'duration'), 'r') as duration_file:\n",
    "        t = float(duration_file.readline().strip())\n",
    "    preamble = \"shared\" if \"shared\" in d else \"private\"\n",
    "    agents = \"nc\" if \"classic\" in d else \"nn\"\n",
    "    durations[preamble][agents][t].append(d)\n",
    "\n",
    "# Sort every bucket in place so later cells visit trials in a stable, name-based order.\n",
    "for by_agents in durations.values():\n",
    "    for by_duration in by_agents.values():\n",
    "        for trial_dirs in by_duration.values():\n",
    "            trial_dirs.sort()\n",
    "print(durations[\"private\"][\"nn\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trial Outcomes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def success_rate(ms):\n",
    "    \"\"\"Return the share of trials in `ms` (a NumPy array) whose value is below 0.1.\n",
    "\n",
    "    The count is divided by a fixed trial total (40 when `ms` holds at most 40 values,\n",
    "    otherwise 66) rather than by `ms.size`, so values missing from `ms` lower the rate.\n",
    "    \"\"\"\n",
    "    expected_trials = 40. if ms.size <= 40 else 66.\n",
    "    return sum(ms < 0.1) / expected_trials\n",
    "\n",
    "def do_runtime(dirs, sharedstr, agentstr, timestr):\n",
    "    \"\"\"Stem-plot the mean read from each trial's 'results' file and print the success rate.\n",
    "\n",
    "    Args:\n",
    "        dirs: trial directories; each is expected to contain a five-line 'results' file.\n",
    "        sharedstr, agentstr, timestr: values substituted into the plot title.\n",
    "\n",
    "    Returns:\n",
    "        (mins, means, maxes): lists of floats with one entry per trial whose 'results'\n",
    "        file could be read and parsed; the three lists always have the same length.\n",
    "    \"\"\"\n",
    "    mins = []\n",
    "    means = []\n",
    "    maxes = []\n",
    "    for d in dirs:\n",
    "        try:\n",
    "            with open(join(d, 'results'), 'r') as f:\n",
    "                # Exactly five lines are expected; lines 2-4 carry mean, min and max.\n",
    "                _, lmean, lmin, lmax, _ = f.readlines()\n",
    "            # Parse every value before storing any of them, so a malformed file cannot\n",
    "            # leave the three lists with different lengths.\n",
    "            trial_min = float(lmin.strip().split()[-1])\n",
    "            trial_mean = float(lmean.strip().split()[-1])\n",
    "            trial_max = float(lmax.strip().split()[-1])\n",
    "        except (OSError, ValueError, IndexError):\n",
    "            # Best effort: skip trials whose 'results' file is missing or malformed, but let\n",
    "            # anything else (e.g. KeyboardInterrupt) propagate instead of being swallowed.\n",
    "            continue\n",
    "        mins.append(trial_min)\n",
    "        means.append(trial_mean)\n",
    "        maxes.append(trial_max)\n",
    "\n",
    "    plt.stem(means)\n",
    "    plt.title(\"{} preamble {}, {} seconds trials\".format(\n",
    "                sharedstr, agentstr, timestr))\n",
    "    plt.show()\n",
    "    print(\"Success rate\", success_rate(np.array(means)))\n",
    "    return mins, means, maxes\n",
    "\n",
    "# Report every duration bucket for each preamble/agent combination.\n",
    "for sharedstr in (\"shared\", \"private\"):\n",
    "    for agentstr in (\"nn\", \"nc\"):\n",
    "        print(\"=====\", sharedstr.upper(), agentstr.upper(), \"=====\")\n",
    "        for time, trial_dirs in durations[sharedstr][agentstr].items():\n",
    "            do_runtime(trial_dirs, sharedstr, agentstr, time)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Constellations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Use .get() rather than [] here: indexing the defaultdict with a missing key would insert\n",
    "# an empty 500 bucket, which the later cells would then iterate over as a real duration.\n",
    "for d in durations['private']['nn'].get(500, []):\n",
    "    # Lexicographically last file in the trial directory whose name contains both\n",
    "    # 'neural_mod' and 'npy'.\n",
    "    mod_files = sorted(name for name in listdir(d) if 'neural_mod' in name and 'npy' in name)\n",
    "    final_mod = mod_files[-1]\n",
    "    const = np.load(join(d, final_mod))\n",
    "    plt.scatter(const.real, const.imag, label=\"Constellation\")\n",
    "    plt.scatter([np.mean(const).real], [np.mean(const).imag], label=\"Constellation Center\", marker='*')\n",
    "    plt.gca().add_artist(plt.Circle((0,0), 1./np.sqrt(2), fill=False, linestyle='--', label=\"Unit Circle\"))\n",
    "#     plt.gca().add_artist(plt.Rectangle((-1/np.sqrt(2),-1/np.sqrt(2)), 2/np.sqrt(2), 2/np.sqrt(2), fill=False, color='g', linewidth=1.5, label=\"Tx Bound\"))\n",
    "    plt.gca().add_artist(plt.Rectangle((-1,-1), 2, 2, fill=False, color='r', linewidth=1.5, label=\"Tx Bound\"))\n",
    "    plt.title(d)\n",
    "    plt.legend()\n",
    "    plt.xlim([-1.5, 1.5])\n",
    "    plt.ylim([-1.5, 1.5])\n",
    "    plt.gca().set_aspect('equal')\n",
    "    plt.grid()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trained Symbols vs BER"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "def do_symbols_ber(durations, sharedstr, agentstr):\n",
    "    \"\"\"Scatter-plot the mean from each trial's 'results' file against its symbol count.\n",
    "\n",
    "    Args:\n",
    "        durations: mapping of duration -> list of trial directories.\n",
    "        sharedstr, agentstr: values substituted into the plot title.\n",
    "\n",
    "    Returns:\n",
    "        (mins, means, maxes, symbols): NumPy arrays of equal length, one entry per trial\n",
    "        whose 'results' file could be read and parsed. Nothing is plotted when empty.\n",
    "    \"\"\"\n",
    "    mins = []\n",
    "    means = []\n",
    "    maxes = []\n",
    "    symbols = []\n",
    "    for t in durations:\n",
    "        for d in durations[t]:\n",
    "            try:\n",
    "                with open(join(d, 'results'), 'r') as f:\n",
    "                    # Exactly five lines are expected; each value is the last token of its line.\n",
    "                    _, lmean, lmin, lmax, lsymbs = f.readlines()\n",
    "                # Parse every field before storing any of them, so a malformed file cannot\n",
    "                # leave the lists with different lengths (which would break the scatter plot).\n",
    "                trial_min = float(lmin.strip().split()[-1])\n",
    "                trial_mean = float(lmean.strip().split()[-1])\n",
    "                trial_max = float(lmax.strip().split()[-1])\n",
    "                trial_symbols = int(lsymbs.strip().split()[-1])\n",
    "            except (OSError, ValueError, IndexError):\n",
    "                # Best effort: skip trials whose 'results' file is missing or malformed, but\n",
    "                # let anything else (e.g. KeyboardInterrupt) propagate.\n",
    "                continue\n",
    "            mins.append(trial_min)\n",
    "            means.append(trial_mean)\n",
    "            maxes.append(trial_max)\n",
    "            symbols.append(trial_symbols)\n",
    "    mins = np.array(mins)\n",
    "    means = np.array(means)\n",
    "    maxes = np.array(maxes)\n",
    "    symbols = np.array(symbols)\n",
    "    if len(means) > 0:\n",
    "        plt.scatter(symbols / 1e6, means)\n",
    "        plt.title(\"{} preamble {} trials\".format(\n",
    "                    sharedstr, agentstr))\n",
    "        plt.xlabel(\"MSymbols\")\n",
    "        plt.ylabel(\"BER\")\n",
    "        plt.show()\n",
    "    return mins, means, maxes, symbols\n",
    "\n",
    "# One BER-vs-symbols scatter plot per preamble/agent combination.\n",
    "for sharedstr, agentstr in [(s, a) for s in (\"shared\", \"private\") for a in (\"nn\", \"nc\")]:\n",
    "    do_symbols_ber(durations[sharedstr][agentstr], sharedstr, agentstr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def success_rate(ms):\n",
    "    \"\"\"Return the share of trials in `ms` (a NumPy array) whose value is below 0.07.\n",
    "\n",
    "    The count is divided by a fixed trial total chosen from `ms.size` (20, 40 or 66),\n",
    "    not by `ms.size` itself. NOTE: this has the same name as the helper defined in the\n",
    "    'Trial Outcomes' cell, so running this cell replaces that definition.\n",
    "    \"\"\"\n",
    "    if ms.size <= 20:\n",
    "        expected_trials = 20.\n",
    "    elif ms.size <= 40:\n",
    "        expected_trials = 40.\n",
    "    else:\n",
    "        expected_trials = 66.\n",
    "    return sum(ms < 0.07) / expected_trials\n",
    "    \n",
    "\n",
    "# OPTIONAL: get equivalent classic snr\n",
    "snr0 = 10.56335866\n",
    "# NOTE: pickle.load can execute arbitrary code; only point this at a trusted, locally produced file.\n",
    "with open(\"../snr-runs/gain-snr-ber-coeffs.pkl\", \"rb\") as f:\n",
    "    coeffs = pickle.load(f)\n",
    "# Linear fit: slope * x + intercept. The intercept term previously reused index 0 (the slope).\n",
    "# Assumes coefficients are stored highest order first, like the log-quadratic fit below -- TODO confirm.\n",
    "gain2snr = lambda x: coeffs['gain-snr-linear'][0] * x + coeffs['gain-snr-linear'][1]\n",
    "# BER as exp() of a quadratic in SNR.\n",
    "snr2ber = lambda x: np.exp(coeffs['snr-ber-logquadratic'][0] * x ** 2 + \n",
    "                           coeffs['snr-ber-logquadratic'][1] * x + \n",
    "                           coeffs['snr-ber-logquadratic'][2])\n",
    "def ber2snr(ber):\n",
    "    \"\"\"Search for the SNR at which `snr2ber` crosses `ber`.\n",
    "\n",
    "    Starts at 10 and moves in steps of 0.01: down while `snr2ber(guess) <= ber`, up\n",
    "    otherwise. Stops at the first change of direction, or once the guess leaves the open\n",
    "    interval (0, 20), and returns the guess reached after the final step.\n",
    "    \"\"\"\n",
    "    guess = 10\n",
    "    previous = 0  # direction of the previous step: -1 down, +1 up, 0 before the first step\n",
    "    while 0 < guess < 20:\n",
    "        direction = -1 if snr2ber(guess) <= ber else +1\n",
    "        guess += direction * 0.01\n",
    "        if previous == -direction:\n",
    "            # Direction flipped, so the crossing lies within one step of here.\n",
    "            break\n",
    "        previous = direction\n",
    "    return guess\n",
    "\n",
    "def success_rate_snr(snrs, threshold):\n",
    "    \"\"\"Return the share of trials whose SNR is less than `threshold` below `snr0`.\n",
    "\n",
    "    `snrs` is a NumPy array. The count is divided by a fixed trial total chosen from\n",
    "    `snrs.size` (20, 40 or 66), not by `snrs.size` itself.\n",
    "    \"\"\"\n",
    "    within_threshold = sum(snr0 - snrs < threshold)\n",
    "    if snrs.size <= 20:\n",
    "        expected_trials = 20.\n",
    "    elif snrs.size <= 40:\n",
    "        expected_trials = 40.\n",
    "    else:\n",
    "        expected_trials = 66.\n",
    "    return within_threshold / expected_trials\n",
    "    \n",
    "# Reference BERs at 3 dB and 5 dB below snr0.\n",
    "for label, margin_db in ((\"3dB BER:\", 3), (\"5dB BER:\", 5)):\n",
    "    print(label, snr2ber(snr0 - margin_db))\n",
    "# OPTIONAL \n",
    "\n",
    "    \n",
    "def do_success_rate(durations, sharedstr, agentstr, threshold=3):\n",
    "    \"\"\"Plot and return, per duration bucket, the share of trials within `threshold` dB.\n",
    "\n",
    "    For every duration key, the means read from the trials' 'results' files are mapped to\n",
    "    an equivalent SNR with `ber2snr` and scored with `success_rate_snr`.\n",
    "\n",
    "    Args:\n",
    "        durations: mapping of duration -> list of trial directories.\n",
    "        sharedstr, agentstr: values substituted into the plot title.\n",
    "        threshold: passed to `success_rate_snr` and shown in the y-axis label.\n",
    "\n",
    "    Returns:\n",
    "        (mean_symbs, srate): NumPy arrays ordered by `mean_symbs`, one entry per duration.\n",
    "    \"\"\"\n",
    "    mean_symbs = []\n",
    "    srate = []\n",
    "    for t in durations:\n",
    "        means = []\n",
    "        symbols = []\n",
    "        for d in durations[t]:\n",
    "            try:\n",
    "                with open(join(d, 'results'), 'r') as f:\n",
    "                    # Exactly five lines are expected; each value is the last token of its line.\n",
    "                    _, lmean, lmin, lmax, lsymbs = f.readlines()\n",
    "                # min and max are parsed only to validate the file: a trial with any\n",
    "                # malformed field is skipped as a whole, as before.\n",
    "                float(lmin.strip().split()[-1])\n",
    "                trial_mean = float(lmean.strip().split()[-1])\n",
    "                float(lmax.strip().split()[-1])\n",
    "                trial_symbols = int(lsymbs.strip().split()[-1])\n",
    "            except (OSError, ValueError, IndexError):\n",
    "                # Best effort: skip trials whose 'results' file is missing or malformed, but\n",
    "                # let anything else (e.g. KeyboardInterrupt) propagate.\n",
    "                continue\n",
    "            # Store both values together so `means` and `symbols` describe the same trials.\n",
    "            means.append(trial_mean)\n",
    "            symbols.append(trial_symbols)\n",
    "        # np.mean of an empty list is nan, so a bucket with no readable trial gets a nan x-value.\n",
    "        mean_symbs.append(np.mean(symbols))\n",
    "        equiv_snrs = np.array([ber2snr(m) for m in means])\n",
    "        srate.append(success_rate_snr(equiv_snrs, threshold))\n",
    "    mean_symbs = np.array(mean_symbs)\n",
    "    srate = np.array(srate)\n",
    "    isort = np.argsort(mean_symbs)\n",
    "    mean_symbs = mean_symbs[isort]\n",
    "    srate = srate[isort]\n",
    "    if len(srate) > 0:\n",
    "        plt.stem(mean_symbs / 1e6, srate * 100)\n",
    "        plt.title(\"{} preamble {} trials\".format(\n",
    "                    sharedstr, agentstr))\n",
    "        plt.xlabel(\"MSymbols\")\n",
    "        # Label with the threshold actually used; it was hard-coded to 3dB even for 5 dB runs.\n",
    "        plt.ylabel(\"Trials Within {}dB BER (%)\".format(threshold))\n",
    "        plt.xlim([0,2])\n",
    "        plt.ylim([0,100])\n",
    "        plt.show()\n",
    "    return mean_symbs, srate\n",
    "\n",
    "\n",
    "# For every preamble/agent combination, compute the 3 dB and 5 dB success rates and pickle them.\n",
    "for sharedstr in (\"shared\", \"private\"):\n",
    "    for agentstr in (\"nn\", \"nc\"):\n",
    "        trials = durations[sharedstr][agentstr]\n",
    "        symbs, success3 = do_success_rate(trials, sharedstr, agentstr, 3)\n",
    "        symbs, success5 = do_success_rate(trials, sharedstr, agentstr, 5)\n",
    "        out = {\n",
    "            'symbols_exchanged': list(symbs),\n",
    "            'success_rate_3dB': list(success3),\n",
    "            'success_rate_5dB': list(success5),\n",
    "        }\n",
    "        out_path = \"{}-{}-seed_convergence.pkl\".format(sharedstr, agentstr)\n",
    "        with open(out_path, \"wb\") as f:\n",
    "            pickle.dump(out, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Retest Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Copy final mod & demod models\n",
    "model_roots = ('trained-models1', 'trained-models2')\n",
    "for root in model_roots:\n",
    "    makedirs(join('.', root), exist_ok=True)\n",
    "with open(\"trained-models1/dirlist.txt\", 'w') as f:\n",
    "    for shared in \"shared\", \"private\":\n",
    "        for agents in \"nn\", \"nc\":\n",
    "            for t, trial_dirs in durations[shared][agents].items():\n",
    "                print(shared, agents, t)\n",
    "                for d in trial_dirs:\n",
    "                    # Drop the last five characters of the trial directory name.\n",
    "                    dirname = d[:-5]\n",
    "                    # Write the same three metadata files into both model trees.\n",
    "                    for root in model_roots:\n",
    "                        makedirs(join('.', root, dirname), exist_ok=True)\n",
    "                        for fname, content in ((\"shared\", shared), (\"mode\", agents), (\"duration\", str(t))):\n",
    "                            with open(join(root, dirname, fname), \"w\") as fmeta:\n",
    "                                fmeta.write(content)\n",
    "                    if 'srn1' in d:\n",
    "                        try:\n",
    "                            mod_src = glob.glob(join(d, 'mod_*.mdl'))[0]\n",
    "                            sh.copy(mod_src, join('trained-models1', dirname, 'mod-weights.mdl'))\n",
    "                            demod_src = glob.glob(join(d, 'demod_*.mdl'))[0]\n",
    "                            sh.copy(demod_src, join('trained-models1', dirname, 'demod-weights.mdl'))\n",
    "                        except IndexError:\n",
    "                            # No matching model file: flag the trial and leave it out of dirlist.txt.\n",
    "                            print(\"***\", dirname)\n",
    "                            continue\n",
    "                        f.write(dirname + \"\\n\")\n",
    "                    if 'srn2' in d:\n",
    "                        try:\n",
    "                            mod_src = glob.glob(join(d, 'mod_*.mdl'))[0]\n",
    "                            sh.copy(mod_src, join('trained-models2', dirname, 'mod-weights.mdl'))\n",
    "                            demod_src = glob.glob(join(d, 'demod_*.mdl'))[0]\n",
    "                            sh.copy(demod_src, join('trained-models2', dirname, 'demod-weights.mdl'))\n",
    "                        except IndexError:\n",
    "                            # Missing srn2 model files are ignored silently.\n",
    "                            pass\n",
    "\n",
    "# Both model trees share the directory list written above.\n",
    "sh.copy(\"./trained-models1/dirlist.txt\", \"./trained-models2/dirlist.txt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
